from pecos.xmc.xtransformer.matcher import TransformerMatcher
from pecos.xmc.xtransformer.module import XMCDataset
from pecos.utils.featurization.text.preprocess import Preprocessor
from torch.utils.data import DataLoader, SequentialSampler, BatchSampler
import torch
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric.utils import to_undirected
import scipy.sparse as ss
import numpy as np
import argparse
from tqdm import tqdm
from copy import deepcopy
import torch.nn as nn
from pecos.utils import torch_util
from transformers import AdamW, AutoConfig, get_linear_schedule_with_warmup
import time
import os
import json
from pecos.xmc.xtransformer.network import ENCODER_CLASSES

from transformers import (
    BertConfig,
    BertModel,
    BertPreTrainedModel,
    BertTokenizer
)

def Triplet_loss(output_i, output_p, output_n, margin=1.0):
    """Per-example triplet margin loss on Euclidean distances.

    For each example (distances are taken over the last axis)::

        loss = max(0, margin - (||i - n|| - ||i - p||))

    i.e. the loss is zero once the negative is at least ``margin`` farther from
    the anchor than the positive.

    Args:
        output_i: anchor embeddings.
        output_p: positive embeddings, broadcastable with ``output_i``.
        output_n: negative embeddings, broadcastable with ``output_i``.
        margin: required gap between the negative and positive distances
            (1.0, the previously hard-coded value, by default).

    Returns:
        Tensor of per-example losses (the input shape without its last axis).
        No reduction is applied; callers reduce it themselves (e.g. ``.mean()``).
    """
    # torch.norm's backward yields 0 (not NaN) when a distance is exactly zero,
    # whereas torch.sqrt(sum(diff ** 2)) has an infinite derivative at 0.
    pos_dist = torch.norm(output_i - output_p, p=2, dim=-1)
    neg_dist = torch.norm(output_i - output_n, p=2, dim=-1)
    return torch.relu(margin - (neg_dist - pos_dist))

def Save(matcher, Encoder, save_dir):
    """Write the trained encoder plus the matcher's tokenizer and head to disk.

    Resulting layout under ``save_dir`` (created if missing):
      * ``text_encoder/``   -- ``save_pretrained`` output of the encoder
      * ``text_tokenizer/`` -- ``save_pretrained`` output of ``matcher.text_tokenizer``
      * ``text_model``      -- ``torch.save`` of ``matcher.text_model``

    Args:
        matcher: object exposing ``text_tokenizer`` and ``text_model``.
        Encoder: model with ``save_pretrained``. If it carries a ``.module``
            attribute (e.g. a ``torch.nn.DataParallel`` wrapper), the inner
            module is what gets saved.
        save_dir: destination directory.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Parallel wrappers keep the real model under `.module`.
    inner_encoder = getattr(Encoder, "module", Encoder)

    def _ensure_subdir(name):
        # Create (if needed) and return save_dir/name.
        path = os.path.join(save_dir, name)
        os.makedirs(path, exist_ok=True)
        return path

    encoder_path = _ensure_subdir("text_encoder")
    inner_encoder.save_pretrained(encoder_path)

    tokenizer_path = _ensure_subdir("text_tokenizer")
    matcher.text_tokenizer.save_pretrained(tokenizer_path)

    torch.save(matcher.text_model, os.path.join(save_dir, "text_model"))


class BertForLP(torch.nn.Module):
    """Encoder returning pooled outputs, with an optional dropout + linear head.

    The encoder is a deep copy of ``matcher.text_encoder.bert``, so training this
    module does not modify the matcher's own encoder.
    """

    def __init__(self, matcher):
        super(BertForLP, self).__init__()
        # Private copy of the matcher's encoder (deepcopy, not a shared reference).
        self.bert = deepcopy(matcher.text_encoder.bert)
        # Size the head from the encoder's config instead of hard-coding 768 so
        # encoders with other hidden sizes work too; 768 remains the fallback
        # when the encoder exposes no `config.hidden_size`.
        config = getattr(self.bert, "config", None)
        self.hidden = getattr(config, "hidden_size", 768)
        self.hidden_dropout_prob = 0.1
        self.dropout = nn.Dropout(self.hidden_dropout_prob)
        self.lin = nn.Linear(self.hidden, self.hidden)

    def init_from(self, model):
        """Replace the wrapped encoder with ``model.bert`` (shared, not copied)."""
        self.bert = model.bert

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        with_lin_head=False,
    ):
        """Run the encoder and return its ``'pooler_output'``.

        Args:
            input_ids, attention_mask, token_type_ids: forwarded unchanged to
                the wrapped encoder.
            with_lin_head: if True, apply dropout and the linear head to the
                pooled output before returning it.
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )["pooler_output"]

        if with_lin_head:
            return self.lin(self.dropout(outputs))
        return outputs

def Gen_ipn_triplet(A, rng=None, show_progress=True):
    """Sample one (anchor, positive, negative) node triplet per non-isolated node.

    For every node ``i`` whose row in ``A`` has at least one stored entry:
      * the positive is drawn uniformly from the neighbours of ``i`` (row ``i``);
      * the negative is drawn uniformly, by rejection sampling, from the nodes
        that are neither ``i`` itself nor a neighbour of ``i``.

    Args:
        A: square scipy sparse adjacency matrix (converted to CSR if needed).
        rng: optional object with a ``choice`` method, e.g.
            ``np.random.default_rng(seed)``; defaults to the global
            ``np.random`` state, as before.
        show_progress: wrap the loop in a ``tqdm`` progress bar.

    Returns:
        Integer ndarray of shape ``(num_non_isolated_nodes, 3)`` whose columns
        are (anchor, positive, negative); anchors are in ascending order.

    Raises:
        ValueError: if a node is adjacent to every other node, so that no
            valid negative exists (previously this looped forever).
    """
    if rng is None:
        rng = np.random
    A = A.tocsr()
    N = A.shape[0]
    indptr, indices = A.indptr, A.indices

    # Anchors are nodes with a non-empty row, i.e. exactly the nodes we can
    # draw a positive from (robust even if A is not symmetric).
    noniso_nodes_idx = np.flatnonzero(np.diff(indptr) > 0)

    pos_neg_samples = np.zeros(shape=(len(noniso_nodes_idx), 3), dtype=int)
    nodes = tqdm(noniso_nodes_idx) if show_progress else noniso_nodes_idx
    for cnt, idx in enumerate(nodes):
        # Slice the CSR arrays directly; A[idx] would build a new sparse matrix.
        neigh = indices[indptr[idx]:indptr[idx + 1]]
        pos_neg_samples[cnt, 0] = idx
        pos_neg_samples[cnt, 1] = rng.choice(neigh)

        # A negative must be neither a neighbour nor the anchor itself.
        excluded = np.union1d(neigh, idx)
        if excluded.size >= N:
            raise ValueError(
                f"node {idx} is adjacent to every other node; no negative sample exists"
            )
        # Rejection sampling; cheap because real graphs are sparse.
        while True:
            sample = rng.choice(N)
            if sample not in excluded:
                break
        pos_neg_samples[cnt, 2] = sample

    return pos_neg_samples

def simple_load(cls, load_dir):
    """Rebuild a matcher-like object from the directory layout written by ``Save``.

    Reads ``<load_dir>/text_encoder``, whose ``config.json`` ``model_type``
    selects the ``ENCODER_CLASSES`` entry used to load the config, encoder and
    tokenizer. Also reads ``<load_dir>/text_tokenizer`` and, if present, the
    ``torch.save``-d ``<load_dir>/text_model``.

    Args:
        cls: class to instantiate as ``cls(text_encoder, text_tokenizer, text_model)``.
        load_dir: directory previously written by ``Save``.

    Returns:
        ``cls(text_encoder, text_tokenizer, text_model)``.

    Raises:
        ValueError: if the ``text_encoder`` or ``text_tokenizer`` subdirectory
            is missing.
    """
    # This file defines no module-level LOGGER; use a local stdlib logger.
    import logging

    logger = logging.getLogger(__name__)

    # load text_encoder
    encoder_dir = os.path.join(load_dir, "text_encoder")
    if not os.path.isdir(encoder_dir):
        raise ValueError(f"text_encoder does not exist at {encoder_dir}")

    with open(os.path.join(encoder_dir, "config.json"), "r", encoding="utf-8") as fin:
        transformer_type = json.load(fin)["model_type"]
    dnn_type = ENCODER_CLASSES[transformer_type]
    encoder_config = dnn_type.config_class.from_pretrained(encoder_dir)
    text_encoder, loading_info = dnn_type.model_class.from_pretrained(
        encoder_dir, config=encoder_config, output_loading_info=True
    )
    if loading_info["missing_keys"]:
        logger.warning(
            "Weights of {} not initialized from pre-trained text_encoder: {}".format(
                text_encoder.__class__.__name__, loading_info["missing_keys"]
            )
        )

    # load text_tokenizer
    tokenizer_dir = os.path.join(load_dir, "text_tokenizer")
    if not os.path.isdir(tokenizer_dir):
        raise ValueError(f"text_tokenizer does not exist at {tokenizer_dir}")
    text_tokenizer = dnn_type.tokenizer_class.from_pretrained(tokenizer_dir)

    # load text_model
    text_model_dir = os.path.join(load_dir, "text_model")
    if os.path.exists(text_model_dir):
        # torch.load unpickles arbitrary objects: only load directories you trust.
        text_model = torch.load(text_model_dir)
    else:
        # Not imported at file level; pulled in only on this fallback path.
        # NOTE(review): assumes PECOS' TransformerLinearXMCHead(hidden_size, num_labels)
        # signature -- confirm against the installed pecos version.
        from pecos.xmc.xtransformer.network import TransformerLinearXMCHead

        text_model = TransformerLinearXMCHead(
            encoder_config.hidden_size, encoder_config.num_labels
        )
        logger.warning(
            f"XMC text_model of {text_encoder.__class__.__name__} not initialized from pre-trained model."
        )

    return cls(
        text_encoder,
        text_tokenizer,
        text_model,
    )


def _parse_args():
    """Define and parse the command-line options of this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--batch_size",
        type=int,
        default=128,
        help="Minibatch size (in terms of labels (edges)!)",
    )
    parser.add_argument(
        "--batch_gen_workers",
        type=int,
        default=64,
        help="Number of workers for text2tensor",
    )
    parser.add_argument(
        "--truncate_length",
        type=int,
        default=128,
        help="Truncate length for raw corpus.",
    )
    parser.add_argument(
        "--max_steps",
        type=int,
        default=20000,
        help="Max steps for training",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="gradient_accumulation_steps",
    )
    parser.add_argument(
        "--num_train_epochs",
        type=int,
        default=5,
        help="num_train_epochs",
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0.0,
        help="weight_decay",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=6e-5,
        help="learning_rate",
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-8,
        help="adam_epsilon",
    )
    parser.add_argument(
        "--warmup_steps",
        type=int,
        default=2000,
        help="warmup_steps",
    )
    parser.add_argument(
        "--max_grad_norm",
        type=float,
        default=1.0,
        help="max_grad_norm",
    )
    parser.add_argument(
        "--logging_steps",
        type=int,
        default=1000,
        help="logging_steps",
    )
    # Choose save_steps=max_steps for the SSL setting.
    parser.add_argument(
        "--save_steps",
        type=int,
        default=20000,
        help="save_steps",
    )
    parser.add_argument(
        "--model_dir",
        type=str,
        required=True,
        help="model_dir for saving the model",
    )
    parser.add_argument(
        "--text_path",
        type=str,
        required=True,
        help="Path for raw text.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Which dataset",
    )
    return parser.parse_args()


def _encode_all_nodes(encoder, data, num_nodes, device, batch_size):
    """Return the encoder's ``'pooler_output'`` for nodes ``0..num_nodes-1``.

    ``data[idx]`` must return a tuple whose first three items are the
    ``input_ids``, ``attention_mask`` and ``token_type_ids`` tensors for the
    1-D index tensor ``idx``. Rows of the result follow node order.
    """
    encoder.eval()
    loader = DataLoader(np.arange(num_nodes), batch_size=batch_size, shuffle=False)
    chunks = []
    # Inference only. no_grad also matters because .numpy() refuses tensors
    # that require grad.
    with torch.no_grad():
        for node_idx in loader:
            batch = data[node_idx]  # 1-D index -> each tensor is (batch, seq_len)
            pooled = encoder(
                input_ids=batch[0].to(device),
                attention_mask=batch[1].to(device),
                token_type_ids=batch[2].to(device),
            )["pooler_output"]
            chunks.append(pooled.cpu().numpy())
    return np.concatenate(chunks, axis=0)


def main():
    """Self-supervised link-prediction fine-tuning of BERT, then node featurization.

    Steps:
      0. Download ``bert-base-uncased`` through ``TransformerMatcher``. Build
         the symmetric CSR adjacency matrix ``A`` of the OGB graph
         ``--dataset``. Tokenize the raw node texts from ``--text_path``.
      1. Sample (anchor, positive, negative) triplets with ``Gen_ipn_triplet``
         and set up the encoder, AdamW and a linear warmup schedule.
      2. Train the encoder with ``Triplet_loss`` on pooled outputs. Triplets
         are re-sampled every epoch; checkpoints are written with ``Save``
         every ``--save_steps`` optimizer steps.
      3. Encode every node and save the pooled outputs to
         ``<model_dir>/Bert_SSL_LinkPred.npy``.

    Assumes (not checked here) that text line ``k`` belongs to graph node ``k``.
    """
    args = _parse_args()

    # # Step 0: Load model and data
    # Download bert base model from hugging face.
    matcher = TransformerMatcher.download_model("bert-base-uncased", 2)

    # Load the graph and build a symmetric CSR adjacency matrix.
    dataset = PygNodePropPredDataset(name=args.dataset, root="./dataset")
    graph = dataset[0]
    N = graph.num_nodes
    # Pass num_nodes by keyword: newer PyG versions take edge_attr as the
    # second positional argument.
    row, col = to_undirected(graph.edge_index, num_nodes=N)
    row, col = row.cpu().numpy(), col.cpu().numpy()
    A = ss.csr_matrix((np.ones_like(row), (row, col)), shape=(N, N))
    print("Loaded adjacency matrix A with {} nodes".format(N))

    # Load raw text
    _, corpus = Preprocessor.load_data_from_file(
        args.text_path,
        label_text_path=None,
        text_pos=0,
    )
    print("Loaded {} training sequences".format(len(corpus)))

    # Apply tokenizer to the raw text
    X = matcher.text_to_tensor(
        corpus,
        num_workers=args.batch_gen_workers,
        max_length=args.truncate_length,
    )
    print("Raw text to Tensor X done!")

    # # Step 1: Preparing Encoder, handling data format and training settings.
    # Each minibatch holds batch_size triplets, i.e. up to 3 * batch_size texts.
    Idx_minibatch = DataLoader(Gen_ipn_triplet(A), batch_size=args.batch_size, shuffle=True)
    Encoder = deepcopy(matcher.text_encoder.bert)
    device, n_gpu = torch_util.setup_device(True)
    Encoder.to(device)
    if n_gpu > 1 and not isinstance(Encoder, torch.nn.DataParallel):
        Encoder = torch.nn.DataParallel(Encoder)

    # Use XMC dataset format
    train_data = XMCDataset(
        X["input_ids"],
        X["attention_mask"],
        X["token_type_ids"],
        torch.arange(X["input_ids"].shape[0]),  # instance number
    )

    # Stopping criteria. max(1, ...) avoids a ZeroDivisionError when an epoch
    # has fewer batches than gradient_accumulation_steps.
    steps_per_epoch = max(1, len(Idx_minibatch) // args.gradient_accumulation_steps)
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // steps_per_epoch + 1
    else:
        t_total = steps_per_epoch * args.num_train_epochs

    # Prepare optimizer, disable weight decay for bias and layernorm weights
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p
                for n, p in Encoder.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [
                p
                for n, p in Encoder.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(
        optimizer_grouped_parameters,
        lr=args.learning_rate,
        eps=args.adam_epsilon,
    )
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )

    logging_steps = args.logging_steps
    max_steps = args.max_steps

    global_step = 0
    last_logged_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    total_train_time, logging_elapsed = 0.0, 0.0
    cur_loss = float("nan")
    done = False

    # # Step 2: Start training
    Encoder.zero_grad()
    for epoch in range(1, int(args.num_train_epochs) + 1):
        if epoch > 1:
            # Create new training triplets for each epoch.
            Idx_minibatch = DataLoader(
                Gen_ipn_triplet(A), batch_size=args.batch_size, shuffle=True
            )
        Encoder.train()

        for batch_cnt, ipn_idx in enumerate(Idx_minibatch):
            start_time = time.time()
            # ipn_idx is (batch, 3), so each tensor is (batch, 3, feat_dim).
            batch_data = tuple(t.to(device) for t in train_data[ipn_idx])

            # Encode anchors, positives and negatives (in that order).
            output_i, output_p, output_n = (
                Encoder(
                    input_ids=batch_data[0][:, k, :],
                    attention_mask=batch_data[1][:, k, :],
                    token_type_ids=batch_data[2][:, k, :],
                )["pooler_output"]
                for k in range(3)
            )

            loss = Triplet_loss(output_i, output_p, output_n).mean()
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            loss.backward()
            tr_loss += loss.item()

            elapsed = time.time() - start_time
            logging_elapsed += elapsed
            total_train_time += elapsed

            if (batch_cnt + 1) % args.gradient_accumulation_steps != 0:
                continue

            torch.nn.utils.clip_grad_norm_(Encoder.parameters(), args.max_grad_norm)
            optimizer.step()  # perform gradient update
            scheduler.step()  # update learning rate schedule
            optimizer.zero_grad()  # clear gradient accumulation
            global_step += 1

            if (logging_steps > 0 and global_step % logging_steps == 0) or global_step == 1:
                # Average over the optimizer steps since the previous log line
                # (just 1 at the first step), not always over logging_steps.
                window = global_step - last_logged_step
                cur_loss = (tr_loss - logging_loss) / window
                print(
                    "| [{:4d}/{:4d}][{:6d}/{:6d}] | {:4d}/{:4d} batches | ms/batch {:5.4f} | train_loss {:6e} | lr {:.6e}".format(
                        int(epoch),
                        int(args.num_train_epochs),
                        int(global_step),
                        int(t_total),
                        int(batch_cnt),
                        len(Idx_minibatch),
                        logging_elapsed * 1000.0 / window,
                        cur_loss,
                        scheduler.get_last_lr()[0],
                    )
                )
                logging_loss = tr_loss
                logging_elapsed = 0
                last_logged_step = global_step

            if args.save_steps > 0 and global_step % args.save_steps == 0:
                print(
                    "| **** saving model (tr_loss={}) to {} at global_step {} ****".format(
                        cur_loss,
                        args.model_dir,
                        global_step,
                    )
                )
                Save(matcher, Encoder, args.model_dir)
                print("-" * 89)

            # Stop after exactly max_steps optimizer steps.
            if max_steps > 0 and global_step >= max_steps:
                done = True
                break

        if done:
            print('Training Done!')
            break

    # Step 3: Generate node features from the trained encoder and save them.
    print("Start generating node features!")
    embedding = _encode_all_nodes(
        Encoder,
        train_data,
        A.shape[0],
        device,
        batch_size=max(n_gpu, 1) * 64,  # n_gpu may be 0 on CPU
    )

    print("Got node features. Start saving node features")
    # model_dir only exists yet if a checkpoint was saved during training.
    os.makedirs(args.model_dir, exist_ok=True)
    np.save(os.path.join(args.model_dir, 'Bert_SSL_LinkPred.npy'), embedding)
    print("Embeddings saved.")


if __name__ == "__main__":
    main()